import torch
import numpy as np

from dada.optimizer import WDA, UGM, DADA, DoG, Prodigy

from dada.model import ModelRunner
from dada.model.norm import TorchNormModel
from dada.utils import Param, run_different_opts, plot_optimizers_result


class NormRunner(ModelRunner):
    """Compare several optimizers on ``TorchNormModel`` for one or more ``p``.

    For every ``p`` in ``self.p_list`` a fresh random starting point is drawn,
    one model/optimizer pair is built per method (WDA, UGM, DoG, Prodigy,
    DADA), all pairs are run through ``run_different_opts`` and the collected
    results are finally handed to ``plot_optimizers_result``.
    """

    def __init__(self, params):
        """Read the ``p`` values to sweep over and initialise the base runner.

        Args:
            params: Mapping with the keys ``'p_list'`` and, when ``'p_list'``
                is ``None``, ``'p'``. It is forwarded unchanged to
                ``ModelRunner.__init__``.

        Raises:
            ValueError: If both ``params['p_list']`` and ``params['p']`` are
                ``None``.
        """
        p_list = params['p_list']
        if p_list is None:
            p = params['p']
            if p is None:
                raise ValueError("p is None (neither 'p_list' nor 'p' was provided)")
            # Fall back to a one-element sweep. This must land in ``p_list``
            # because that is the attribute ``run`` iterates over.
            p_list = [p]
        self.p_list = p_list

        super(NormRunner, self).__init__(params)

    def run(self, iterations, model_name, save_plot, plots_directory):
        """Run every optimizer for each ``p`` and plot the collected results.

        Args:
            iterations: Iteration count forwarded to ``run_different_opts``.
            model_name: Forwarded to ``plot_optimizers_result``.
            save_plot: Forwarded to ``plot_optimizers_result`` as ``save``.
            plots_directory: Forwarded to ``plot_optimizers_result``.
        """
        # ``self.vector_size`` is not set in this class; presumably it is
        # provided by ``ModelRunner.__init__`` -- verify against the base class.
        params = [Param(names=["p", "d"], values=[p, self.vector_size]) for p in self.p_list]
        value_distances_per_param = {}
        d_estimation_error_per_param = {}

        # Logging / marker interval: a tenth of the run, but never 0 so that
        # short runs (iterations < 10) do not pass a zero step downstream.
        # NOTE(review): assumes ``log_per`` / ``mark_every`` are used as a
        # positive stride by the callees -- confirm.
        interval = max(1, iterations // 10)

        optimizers = []

        for param in params:
            print(param)
            p = param.get_param("p")
            d = param.get_param("d")
            optimal_point = [0] * d
            init = torch.randn(d, requires_grad=True)
            # Euclidean distance from the starting point to the optimum (the
            # origin); passed to WDA as ``d0``.
            d0 = np.linalg.norm(np.asarray(optimal_point) - init.detach().numpy())

            # NOTE(review): every model below receives the very same ``init``
            # tensor. This is only safe if ``TorchNormModel`` copies it --
            # confirm, otherwise the optimizers would share one parameter.

            # Dual Averaging Method
            da_model = TorchNormModel(d, p, init_point=init)
            da_optimizer = WDA(da_model.params(), d0=d0)

            # GD With Line Search Method
            gd_line_search_model = TorchNormModel(d, p, init_point=init)
            gd_line_search_optimizer = UGM(gd_line_search_model.params())

            # DoG Method
            dog_model = TorchNormModel(d, p, init_point=init)
            dog_optimizer = DoG(dog_model.params())

            # Prodigy Method
            prodigy_model = TorchNormModel(d, p, init_point=init)
            prodigy_optimizer = Prodigy(prodigy_model.params(), lr=0.01)

            # DADA Method
            dada_model = TorchNormModel(d, p, init_point=init)
            dada_optimizer = DADA(dada_model.params())

            optimizers = [[da_optimizer, da_model],
                          [gd_line_search_optimizer, gd_line_search_model],
                          [dog_optimizer, dog_model],
                          [prodigy_optimizer, prodigy_model],
                          [dada_optimizer, dada_model]]

            d_estimation_error, value_distances = run_different_opts(
                optimizers, iterations, optimal_point, log_per=interval)
            value_distances_per_param[param] = value_distances
            d_estimation_error_per_param[param] = d_estimation_error

        # ``optimizers`` holds the pairs built for the last ``p`` (or is empty
        # when ``p_list`` is empty).
        plot_optimizers_result(optimizers, params, value_distances_per_param, d_estimation_error_per_param,
                               model_name=model_name, save=save_plot, plots_directory=plots_directory,
                               mark_every=interval)
